from tqdm import tqdm

import torch
import torch.nn.functional as F
import torch.nn as nn
from datasets.imagenet import ImageNet
from datasets import build_dataset
from datasets.utils import build_data_loader, AugMixAugmenter
import clip
import torchvision.transforms as transforms
from PIL import Image
import operator
import numpy as np

# Prefer torchvision's InterpolationMode enum when this torchvision version
# provides it; otherwise fall back to the equivalent PIL constant.
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

def cls_acc(output, target, topk=1):
    """Return the top-k classification accuracy as a percentage.

    Args:
        output: 2-D score tensor; ``topk`` is taken along dimension 1.
        target: 1-D tensor of ground-truth class indices, one per row of
            ``output``.
        topk: number of highest-scoring classes that count as a hit.

    Returns:
        A Python float in [0, 100].
    """
    # Indices of the k best classes per sample, transposed to (k, batch).
    pred = output.topk(topk, 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # .item() yields a Python float directly; the previous
    # float(<1-element ndarray>) conversion is deprecated in recent NumPy.
    num_correct = correct[:topk].reshape(-1).float().sum().item()
    return 100 * num_correct / target.shape[0]

def softmax_entropy(x):
    """Return the entropy of softmax(x) taken along dimension 1, per row."""
    probabilities = x.softmax(dim=1)
    log_probabilities = x.log_softmax(dim=1)
    weighted = probabilities * log_probabilities
    return -weighted.sum(dim=1)

def avg_entropy(outputs):
    """Return the entropy of the distribution averaged over dimension 0.

    Each row of ``outputs`` is turned into log-probabilities over the last
    dimension, the rows are averaged in probability space (computed in the
    log domain), and the entropy of that averaged distribution is returned.
    """
    # Row-wise log-softmax over the last dimension.
    log_probs = outputs - outputs.logsumexp(dim=-1, keepdim=True)
    row_count = log_probs.shape[0]
    # log(mean_i p_i) = logsumexp_i(log p_i) - log(N)
    mean_log_probs = log_probs.logsumexp(dim=0) - np.log(row_count)
    # Replace -inf with the smallest finite value so that
    # (-inf * 0) does not produce NaN in the product below.
    smallest_finite = torch.finfo(mean_log_probs.dtype).min
    mean_log_probs = mean_log_probs.clamp(min=smallest_finite)
    return -(mean_log_probs * mean_log_probs.exp()).sum(dim=-1)


def clip_classifier(classnames, template, clip_model):
    """Build text-classifier weights from prompt-ensembled class embeddings.

    For every class name, each entry of ``template`` is formatted with the
    name (underscores replaced by spaces), tokenized, and encoded with
    ``clip_model.encode_text``. The per-prompt embeddings are L2-normalised,
    averaged, and the average is normalised again.

    Returns:
        A CUDA tensor with the per-class embeddings stacked along dimension 1
        (one column per class, assuming ``encode_text`` returns a 2-D tensor).
    """
    per_class_embeddings = []

    with torch.no_grad():
        for raw_name in classnames:
            readable_name = raw_name.replace('_', ' ')
            prompts = [pattern.format(readable_name) for pattern in template]
            tokens = clip.tokenize(prompts).cuda()

            # Prompt ensemble: normalise each prompt embedding, average them,
            # then renormalise the mean.
            prompt_embeddings = clip_model.encode_text(tokens)
            prompt_embeddings /= prompt_embeddings.norm(dim=-1, keepdim=True)
            ensemble = prompt_embeddings.mean(dim=0)
            ensemble /= ensemble.norm()
            per_class_embeddings.append(ensemble)

        clip_weights = torch.stack(per_class_embeddings, dim=1).cuda()

    return clip_weights

def get_ood_preprocess():
    """Return the AugMix-based multi-view preprocessing pipeline.

    The geometric stage resizes to 224 with bicubic interpolation and
    centre-crops to 224; the tensor stage converts to a tensor and applies
    per-channel normalisation. Both are handed to ``AugMixAugmenter`` with
    ``n_views=63`` and ``augmix=True``.
    """
    channel_mean = [0.48145466, 0.4578275, 0.40821073]
    channel_std = [0.26862954, 0.26130258, 0.27577711]

    # Geometry first (applied on the PIL image) ...
    resize_and_crop = transforms.Compose([
        transforms.Resize(224, interpolation=BICUBIC),
        transforms.CenterCrop(224),
    ])
    # ... then tensor conversion and normalisation.
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])

    return AugMixAugmenter(resize_and_crop, to_normalized_tensor, n_views=63, augmix=True)

def build_test_data_loader(dataset_name, root_path, preprocess):
    """Create the test loaders for a named dataset.

    Args:
        dataset_name: ``'I'`` for ImageNet, one of ``'A'``, ``'V'``, ``'R'``,
            ``'S'`` for the ``imagenet-<letter>`` variants, or one of the
            fine-grained dataset names listed below.
        root_path: dataset root directory passed to the dataset builder.
        preprocess: transform applied to every image.

    Returns:
        ``(test_loader, classnames, template, test_loader_sup)`` where
        ``test_loader`` is shuffled and ``test_loader_sup`` iterates the same
        test split in order. Both use ``batch_size=1``.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    imagenet_variants = ('A', 'V', 'R', 'S')
    fine_grained = ('caltech101', 'dtd', 'eurosat', 'fgvc', 'food101', 'oxford_flowers', 'oxford_pets',
                    'stanford_cars', 'sun397', 'ucf101')

    if dataset_name == 'I':
        dataset = ImageNet(root_path, preprocess)

        test_loader = torch.utils.data.DataLoader(dataset.test, batch_size=1, num_workers=8, shuffle=True)
        test_loader_sup = torch.utils.data.DataLoader(dataset.test, batch_size=1, num_workers=8, shuffle=False)

    elif dataset_name in imagenet_variants or dataset_name in fine_grained:
        # ImageNet variants are registered as "imagenet-a", "imagenet-v", ...;
        # the remaining datasets are registered under their own name.
        if dataset_name in imagenet_variants:
            registry_name = f"imagenet-{dataset_name.lower()}"
        else:
            registry_name = dataset_name

        dataset = build_dataset(registry_name, root_path)
        test_loader = build_data_loader(data_source=dataset.test, batch_size=1, is_train=False, tfm=preprocess,
                                        shuffle=True)
        test_loader_sup = build_data_loader(data_source=dataset.test, batch_size=1, is_train=False, tfm=preprocess,
                                            shuffle=False)

    else:
        # Raising a bare string is itself a TypeError in Python 3; raise a
        # real exception that names the offending value instead.
        raise ValueError(f"Dataset is not from the chosen list: {dataset_name!r}")

    return test_loader, dataset.classnames, dataset.template, test_loader_sup

def pre_load_features(cache_cfg, clip_model, loader, clip_weights, need_cache):
    """Extract features for every sample in ``loader`` and optionally build a cache.

    Args:
        cache_cfg: mapping with an ``'enabled'`` flag and, when a cache is
            requested, a ``'shot_capacity'`` entry.
        clip_model: model passed through to ``get_clip_logits``.
        loader: iterable yielding ``(images, target)`` pairs.
        clip_weights: text-classifier weights passed to ``get_clip_logits``.
        need_cache: when truthy, also build the per-prediction cache.

    Returns:
        ``(features, labels, cache)``. ``features`` and ``labels`` are the
        per-sample results concatenated along dimension 0. ``cache`` maps a
        predicted class to a list of ``[image_features, loss]`` entries
        (layout: ``{pseudo_label: [[features, loss], ...]}``), or is ``None``
        when ``need_cache`` is falsy.

    Raises:
        ValueError: if ``need_cache`` is set but ``cache_cfg['enabled']`` is
            falsy. Previously this combination failed with a ``NameError``
            in the middle of the loop.
    """
    cache_enabled = cache_cfg['enabled']
    if need_cache and not cache_enabled:
        raise ValueError("need_cache is set but cache_cfg['enabled'] is false")
    shot_capacity = cache_cfg['shot_capacity'] if need_cache else None

    features, labels = [], []
    pos_cache = {} if need_cache else None

    with torch.no_grad():
        for images, target in tqdm(loader):
            image_features, _, pred, loss = get_clip_logits(images, clip_model, clip_weights)

            features.append(image_features)
            labels.append(target.cuda())
            if need_cache:
                update_cache(pos_cache, pred, [image_features, loss], shot_capacity)

    return torch.cat(features), torch.cat(labels), pos_cache

def update_cache(cache, pred, features_loss, shot_capacity, include_prob_map=False):
    """Insert an entry into the per-class cache, keeping at most ``shot_capacity`` items.

    ``cache`` maps a predicted class to a list of entries kept sorted by
    ascending loss (element 1 of each entry). While a class has spare
    capacity the new entry is appended; once full, the new entry replaces
    the current worst (last) entry only if its loss is strictly lower.
    The dictionary is modified in place.
    """
    with torch.no_grad():
        if include_prob_map:
            entry = features_loss[:2] + [features_loss[2]]
        else:
            entry = features_loss

        # First entry for this class: start a new bucket.
        if pred not in cache:
            cache[pred] = [entry]
            return

        bucket = cache[pred]
        if len(bucket) < shot_capacity:
            bucket.append(entry)
        elif features_loss[1] < bucket[-1][1]:
            # Bucket is full: evict the highest-loss entry.
            bucket[-1] = entry
        cache[pred] = sorted(bucket, key=operator.itemgetter(1))

def get_clip_logits(images, clip_model, clip_weights):
    """Encode images and score them against the text-classifier weights.

    Args:
        images: a tensor, or a list of tensors that is concatenated along
            dimension 0 (e.g. several augmented views of one sample).
        clip_model: model exposing ``encode_image``.
        clip_weights: matrix the normalised image features are multiplied
            with to obtain logits.

    Returns:
        ``(image_features, clip_logits, pred, loss)``. With more than one
        input row, only the lowest-entropy 10% of rows (at least one) are
        kept; features and logits are their means (each with a leading
        dimension of 1) and ``loss`` is ``avg_entropy`` over the kept logits.
        With a single row, ``loss`` is its softmax entropy. ``pred`` is the
        arg-max class as a Python int.
    """
    with torch.no_grad():
        if isinstance(images, list):
            images = torch.cat(images, dim=0).cuda()
        else:
            images = images.cuda()

        image_features = clip_model.encode_image(images)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        clip_logits = 100. * image_features @ clip_weights

        n_views = image_features.size(0)
        if n_views > 1:
            batch_entropy = softmax_entropy(clip_logits)
            # int(n_views * 0.1) is 0 for fewer than 10 views, which would
            # select nothing and turn every mean below into NaN; always keep
            # at least the single most confident view.
            n_selected = max(1, int(n_views * 0.1))
            selected_idx = torch.argsort(batch_entropy, descending=False)[:n_selected]
            output = clip_logits[selected_idx]
            image_features = image_features[selected_idx].mean(0).unsqueeze(0)
            clip_logits = output.mean(0).unsqueeze(0)
            pred = int(clip_logits.topk(1, 1, True, True)[1].t())
            loss = avg_entropy(output)
        else:
            loss = softmax_entropy(clip_logits)
            pred = int(clip_logits.topk(1, 1, True, True)[1].t()[0])

        return image_features, clip_logits, pred, loss


